/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import Unsafe from './Unsafe';
import Arrays from '../util/Arrays';
import IllegalArgumentException from '../util/IllegalArgumentException';
import Huffman from './Huffman'
import NodeTable from './NodeTable'
import Util from './Util'
import HuffmanCompressionTableWorkspace from './HuffmanCompressionTableWorkspace'
import BitOutputStream from './BitOutputStream'
import Long from '../util/long/index'
import HuffmanTableWriterWorkspace from './HuffmanTableWriterWorkspace'
import Histogram from './Histogram'
import { FiniteStateEntropy } from './FiniteStateEntropy'
import FseCompressionTable from './FseCompressionTable'
import NumberTransform from './NumberTransform'

/**
 * Huffman compression table.
 *
 * Holds, per symbol, a code value (`values`) and a code length (`numberOfBits`). The table is
 * built from a histogram of symbol counts by `initialize`, used for encoding by `encodeSymbol`,
 * and serialized (as a list of symbol weights) by `write`.
 */
export default class HuffmanCompressionTable {
    /** Code value assigned to each symbol, indexed by symbol. */
    private readonly values: Int16Array;
    /** Code length in bits for each symbol, indexed by symbol; 0 means the symbol has no code. */
    private readonly numberOfBits: Int8Array;
    /** Largest symbol covered by the table; set by `initialize`. */
    private maxSymbol: number = 0;
    /** Length of the longest code after height limiting; set by `initialize`. */
    private maxNumberOfBits: number = 0;

    /**
     * @param capacity number of symbol slots to allocate for code values and code lengths
     */
    constructor(capacity: number) {
        this.values = new Int16Array(capacity);
        this.numberOfBits = new Int8Array(capacity);
    }

    /**
     * Picks the maximum code length (table log) to use for a Huffman table.
     *
     * Starts from `maxNumberOfBits`, lowers it for small inputs, raises it to at least
     * `Util.minTableLog(inputSize, maxSymbol)`, and finally clamps it to
     * [Huffman.MIN_TABLE_LOG, Huffman.MAX_TABLE_LOG].
     *
     * @param maxNumberOfBits requested upper bound on the code length
     * @param inputSize number of input symbols
     * @param maxSymbol largest symbol value present in the input
     * @returns the chosen maximum code length
     * @throws IllegalArgumentException when `inputSize <= 1` (not supported; use RLE instead)
     */
    public static optimalNumberOfBits(maxNumberOfBits: number, inputSize: number, maxSymbol: number): number
    {
        if (inputSize <= 1) {
            throw new IllegalArgumentException(); // not supported. Use RLE instead
        }

        let result: number = maxNumberOfBits;

        // we may be able to reduce accuracy if input is small
        result = Math.min(result, Util.highestBit(inputSize - 1) - 1);

        // never go below the minimum table log reported for this input size / alphabet
        result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));

        result = Math.max(result, Huffman.MIN_TABLE_LOG); // absolute minimum for Huffman
        result = Math.min(result, Huffman.MAX_TABLE_LOG); // absolute maximum for Huffman

        return result;
    }

    /**
     * Builds the table from a histogram of symbol counts.
     *
     * Builds a Huffman tree, limits its height to `maxNumberOfBits`, then assigns each symbol its
     * code length and a code value within its length rank (in symbol order).
     *
     * @param counts occurrence count per symbol, indexed by symbol
     * @param maxSymbol largest symbol to include; must be <= Huffman.MAX_SYMBOL
     * @param maxNumberOfBits upper bound on the code length
     * @param workspace scratch buffers; reset by this call
     */
    public initialize(counts: Int32Array, maxSymbol: number, maxNumberOfBits: number, workspace: HuffmanCompressionTableWorkspace): void
    {
        Util.checkArgument(maxSymbol <= Huffman.MAX_SYMBOL, "Max symbol value too large");

        workspace.reset();

        const nodeTable: NodeTable = workspace.nodeTable;
        nodeTable.reset();

        const lastNonZero: number = this.buildTree(counts, maxSymbol, nodeTable);

        // enforce the maximum code length
        maxNumberOfBits = HuffmanCompressionTable.setMaxHeight(nodeTable, lastNonZero, maxNumberOfBits, workspace);
        Util.checkArgument(maxNumberOfBits <= Huffman.MAX_TABLE_LOG, "Max number of bits larger than max table size");

        // copy code lengths from the count-sorted leaf nodes back into symbol order
        const symbolCount: number = maxSymbol + 1;
        for (let node: number = 0; node < symbolCount; node++) {
            const symbol: number = nodeTable.symbols[node];
            this.numberOfBits[symbol] = nodeTable.numberOfBits[node];
        }

        const entriesPerRank: Int16Array = workspace.entriesPerRank;
        const valuesPerRank: Int16Array = workspace.valuesPerRank;

        // count how many symbols use each code length
        for (let n: number = 0; n <= lastNonZero; n++) {
            entriesPerRank[nodeTable.numberOfBits[n]]++;
        }

        // determine the starting code value for each code length, longest codes first
        let startingValue: number = 0;
        for (let rank: number = maxNumberOfBits; rank > 0; rank--) {
            valuesPerRank[rank] = startingValue; // get starting value within each rank
            startingValue += entriesPerRank[rank];
            startingValue >>>= 1;
        }

        for (let n: number = 0; n <= maxSymbol; n++) {
            this.values[n] = valuesPerRank[this.numberOfBits[n]]++; // assign value within rank, symbol order
        }

        this.maxSymbol = maxSymbol;
        this.maxNumberOfBits = maxNumberOfBits;
    }

    /**
     * Builds a Huffman tree in `nodeTable`.
     *
     * Leaves (indices 0..maxSymbol) are filled from `counts` by insertion sort in descending count
     * order. Internal nodes start at index Huffman.MAX_SYMBOL_COUNT. On return,
     * `nodeTable.numberOfBits[i]` holds the depth of leaf i for i in 0..lastNonZero.
     *
     * @returns index of the last leaf with a non-zero count
     */
    private buildTree(counts: Int32Array, maxSymbol: number, nodeTable: NodeTable): number
    {
        let current: number = 0;

        for (let symbol: number = 0; symbol <= maxSymbol; symbol++) {
            const count: number = counts[symbol];

            // simple insertion sort, descending by count
            // NOTE(review): `position > 1` means slot 0 is never displaced, so symbol 0 always stays
            // in slot 0 regardless of its count. Confirm against the reference implementation before
            // changing this (changing it alters the produced code lengths).
            let position: number = current;
            while (position > 1 && count > nodeTable.count[position - 1]) {
                nodeTable.copyNode(position - 1, position);
                position--;
            }

            nodeTable.count[position] = count;
            nodeTable.symbols[position] = symbol;

            current++;
        }

        // every leaf after lastNonZero has a zero count
        let lastNonZero: number = maxSymbol;
        while (nodeTable.count[lastNonZero] === 0) {
            lastNonZero--;
        }

        // populate the non-leaf nodes
        const nonLeafStart: number = Huffman.MAX_SYMBOL_COUNT;
        current = nonLeafStart;

        let currentLeaf: number = lastNonZero;

        // combine the last two non-zero leaves (smallest counts after the sort) into the first internal node
        let currentNonLeaf: number = current;
        nodeTable.count[current] = nodeTable.count[currentLeaf] + nodeTable.count[currentLeaf - 1];
        nodeTable.parents[currentLeaf] = current;
        nodeTable.parents[currentLeaf - 1] = current;
        current++;
        currentLeaf -= 2;

        const root: number = Huffman.MAX_SYMBOL_COUNT + lastNonZero - 1;

        // give not-yet-built internal nodes a large sentinel count so they are not picked as children early
        for (let n: number = current; n <= root; n++) {
            nodeTable.count[n] = 1 << 30;
        }

        // repeatedly merge the two smallest available nodes (leaf or internal) until the root is built
        while (current <= root) {
            let child1: number;
            if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
                child1 = currentLeaf--;
            }
            else {
                child1 = currentNonLeaf++;
            }

            let child2: number;
            if (currentLeaf >= 0 && nodeTable.count[currentLeaf] < nodeTable.count[currentNonLeaf]) {
                child2 = currentLeaf--;
            }
            else {
                child2 = currentNonLeaf++;
            }

            nodeTable.count[current] = nodeTable.count[child1] + nodeTable.count[child2];
            nodeTable.parents[child1] = current;
            nodeTable.parents[child2] = current;
            current++;
        }

        // distribute depths: the root is 0, every other node is one deeper than its parent
        nodeTable.numberOfBits[root] = 0;
        for (let n: number = root - 1; n >= nonLeafStart; n--) {
            const parent: number = nodeTable.parents[n];
            nodeTable.numberOfBits[n] = NumberTransform.toByte(nodeTable.numberOfBits[parent] + 1);
        }

        for (let n: number = 0; n <= lastNonZero; n++) {
            const parent: number = nodeTable.parents[n];
            nodeTable.numberOfBits[n] = NumberTransform.toByte(nodeTable.numberOfBits[parent] + 1);
        }

        return lastNonZero;
    }

    /**
     * Appends the code for `symbol` (its value and length from this table) to `output`.
     */
    public encodeSymbol(output: BitOutputStream, symbol: number): void
    {
        output.addBitsFast(this.values[symbol], this.numberOfBits[symbol]);
    }

    /**
     * Serializes this table's description (symbol weights) to the output.
     *
     * Code lengths are converted to weights (`maxNumberOfBits + 1 - bits`, or 0 for symbols with no
     * code). The weights of symbols 0..maxSymbol-1 are written in one of two ways. The FSE form
     * uses a one-byte size header. The raw form packs 4 bits per weight after a one-byte header
     * of `127 + entryCount`. The last symbol's weight is never written.
     *
     * @param outputBase base object passed through to Unsafe
     * @param outputAddress address where the description starts
     * @param outputSize number of bytes available at outputAddress
     * @param workspace scratch buffers for weight conversion/compression
     * @returns number of bytes written
     * @throws (via Util.checkArgument) "Output size too small" when the raw encoding does not fit
     */
    public write(outputBase: any, outputAddress: Long, outputSize: number, workspace: HuffmanTableWriterWorkspace): number
    {
        const weights: Int8Array = workspace.weights;

        let output: Long = outputAddress;

        const maxNumberOfBits: number = this.maxNumberOfBits;
        const maxSymbol: number = this.maxSymbol;

        // convert code lengths to weights; the last symbol (maxSymbol) is implicit
        for (let symbol: number = 0; symbol < maxSymbol; symbol++) {
            const bits: number = this.numberOfBits[symbol];
            weights[symbol] = bits === 0 ? 0 : maxNumberOfBits + 1 - bits;
        }

        // attempt weights compression by FSE, leaving one byte at outputAddress for the header
        let size: number = HuffmanCompressionTable.compressWeights(outputBase, output.add(1), outputSize - 1, weights, maxSymbol, workspace);

        if (maxSymbol > 127 && size > 127) {
            // treated as an internal error: compressed weights are not expected to exceed 127 bytes here
            throw new Error();
        }

        // Use FSE only if the weights were compressible and the result is smaller than the raw payload.
        // compressWeights returns 0 for "not compressible" and 1 for "single repeated weight".
        if (size !== 0 && size !== 1 && size < maxSymbol / 2) {
            Unsafe.putByte(outputBase, output.toNumber(), size);
            return size + 1; // header + size
        }

        // Raw encoding: 4 bits per weight, two weights per byte.
        // #entries = #symbols - 1 since the last symbol's weight is implicit.
        const entryCount: number = maxSymbol;

        // ceil(entryCount / 2) payload bytes. Integer division is required: `(entryCount + 1) / 2`
        // yields x.5 for even counts and would reject an output buffer that is exactly large enough.
        size = (entryCount + 1) >>> 1;
        Util.checkArgument(size + 1 /* header */ <= outputSize, "Output size too small");

        // NOTE(review): 127 + entryCount only fits in one byte for entryCount <= 128; confirm callers
        // never reach the raw branch with a larger maxSymbol.
        Unsafe.putByte(outputBase, output.toNumber(), 127 + entryCount);
        output = output.add(1);

        weights[maxSymbol] = 0; // last weight is implicit, so set to 0 so that it doesn't get encoded below
        for (let i: number = 0; i < entryCount; i += 2) {
            Unsafe.putByte(outputBase, output.toNumber(), (weights[i] << 4) + weights[i + 1]);
            output = output.add(1);
        }

        return output.subtract(outputAddress).toInt();
    }

    /**
     * Checks whether this table can encode every symbol that has a non-zero count.
     *
     * @param counts occurrence count per symbol, indexed by symbol
     * @param maxSymbol largest symbol to check
     * @returns false if maxSymbol exceeds the table's range or a counted symbol has no code
     */
    public isValid(counts: Int32Array, maxSymbol: number): boolean
    {
        if (maxSymbol > this.maxSymbol) {
            // some symbols with non-zero counts cannot be encoded by the current table
            return false;
        }

        for (let symbol: number = 0; symbol <= maxSymbol; ++symbol) {
            if (counts[symbol] !== 0 && this.numberOfBits[symbol] === 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Estimates the encoded size in bytes (rounded down) of data with the given histogram.
     * Symbols above this table's maxSymbol are ignored.
     *
     * @param counts occurrence count per symbol, indexed by symbol
     * @param maxSymbol largest symbol present in counts
     * @returns sum(codeLength * count) / 8, truncated
     */
    public estimateCompressedSize(counts: Int32Array, maxSymbol: number): number
    {
        const lastSymbol: number = Math.min(maxSymbol, this.maxSymbol);
        let totalBits: number = 0;
        for (let symbol: number = 0; symbol <= lastSymbol; symbol++) {
            totalBits += this.numberOfBits[symbol] * counts[symbol];
        }

        return totalBits >>> 3; // bits -> bytes
    }

    /**
     * Limits the Huffman tree depth to `maxNumberOfBits` by adjusting leaf depths in `nodeTable`.
     *
     * Leaves deeper than the limit are shortened to the limit. The code-space overrun ("cost")
     * this creates is repaid by lengthening shallower leaves, and any overshoot is corrected
     * at the end.
     *
     * @returns the original largest depth if it was already within the limit, else maxNumberOfBits
     */
    private static setMaxHeight(nodeTable: NodeTable, lastNonZero: number, maxNumberOfBits: number, workspace: HuffmanCompressionTableWorkspace): number
    {
        const largestBits: number = nodeTable.numberOfBits[lastNonZero];

        if (largestBits <= maxNumberOfBits) {
            return largestBits; // early exit: no elements > maxNumberOfBits
        }

        // shorten every leaf deeper than the limit, accumulating the cost in units of 2^-largestBits
        let totalCost: number = 0;
        const baseCost: number = 1 << (largestBits - maxNumberOfBits);
        let n: number = lastNonZero;

        while (nodeTable.numberOfBits[n] > maxNumberOfBits) {
            totalCost += baseCost - (1 << (largestBits - nodeTable.numberOfBits[n]));
            nodeTable.numberOfBits[n] = maxNumberOfBits;
            n--;
        }

        while (nodeTable.numberOfBits[n] === maxNumberOfBits) {
            n--; // n ends at index of smallest symbol using < maxNumberOfBits
        }

        // renormalize totalCost to units of 2^-maxNumberOfBits
        totalCost >>>= (largestBits - maxNumberOfBits); // note: totalCost is necessarily a multiple of baseCost

        // Sentinel for "no symbol in this rank". It must be the int32 value of 0xF0F0F0F0: rankLast is
        // an Int32Array, so the plain literal 0xF0F0F0F0 (4042322160) is stored as -252645136 and would
        // never compare equal to the literal when read back.
        const noSymbol: number = 0xF0F0F0F0 | 0;
        const rankLast: Int32Array = workspace.rankLast;
        Arrays.fill(rankLast, noSymbol);

        // record the position of the last (smallest-count) symbol per rank, where rank = maxNumberOfBits - depth
        let currentNbBits: number = maxNumberOfBits;
        for (let pos: number = n; pos >= 0; pos--) {
            if (nodeTable.numberOfBits[pos] >= currentNbBits) {
                continue;
            }
            currentNbBits = nodeTable.numberOfBits[pos]; // < maxNumberOfBits
            rankLast[maxNumberOfBits - currentNbBits] = pos;
        }

        // repay the cost by lengthening one symbol at a time; lengthening a symbol of rank r repays 2^(r-1)
        while (totalCost > 0) {
            // Start from the rank matching the highest bit of the remaining cost. Move to a lower rank
            // while the higher rank is empty or its candidate's count exceeds twice the lower candidate's.
            let numberOfBitsToDecrease: number = Util.highestBit(totalCost) + 1;
            for (; numberOfBitsToDecrease > 1; numberOfBitsToDecrease--) {
                const highPosition: number = rankLast[numberOfBitsToDecrease];
                const lowPosition: number = rankLast[numberOfBitsToDecrease - 1];
                if (highPosition === noSymbol) {
                    continue;
                }
                if (lowPosition === noSymbol) {
                    break;
                }
                const highTotal: number = nodeTable.count[highPosition];
                const lowTotal: number = 2 * nodeTable.count[lowPosition];
                if (highTotal <= lowTotal) {
                    break;
                }
            }

            // if the chosen rank is empty, move up to the next non-empty rank
            while ((numberOfBitsToDecrease <= Huffman.MAX_TABLE_LOG) && (rankLast[numberOfBitsToDecrease] === noSymbol)) {
                numberOfBitsToDecrease++;
            }
            totalCost -= 1 << (numberOfBitsToDecrease - 1);
            if (rankLast[numberOfBitsToDecrease - 1] === noSymbol) {
                rankLast[numberOfBitsToDecrease - 1] = rankLast[numberOfBitsToDecrease]; // this rank is no longer empty
            }
            nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]]++;
            if (rankLast[numberOfBitsToDecrease] === 0) { // special case, reached largest symbol
                rankLast[numberOfBitsToDecrease] = noSymbol;
            }
            else {
                rankLast[numberOfBitsToDecrease]--;
                if (nodeTable.numberOfBits[rankLast[numberOfBitsToDecrease]] !== maxNumberOfBits - numberOfBitsToDecrease) {
                    rankLast[numberOfBitsToDecrease] = noSymbol; // this rank is now empty
                }
            }
        }

        while (totalCost < 0) { // Sometimes, cost correction overshoot
            if (rankLast[1] === noSymbol) {
                // special case: no rank 1 symbol (using maxNumberOfBits - 1); create one from the
                // largest rank 0 symbol (using maxNumberOfBits)
                while (nodeTable.numberOfBits[n] === maxNumberOfBits) {
                    n--;
                }
                nodeTable.numberOfBits[n + 1]--;
                rankLast[1] = n + 1;
                totalCost++;
                continue;
            }
            nodeTable.numberOfBits[rankLast[1] + 1]--;
            rankLast[1]++;
            totalCost++;
        }

        return maxNumberOfBits;
    }

    /**
     * FSE-compresses the first `weightsLength` entries of `weights`. It writes a normalized-counts
     * header followed by the compressed stream.
     *
     * @returns total bytes written, or a special value that is not a byte count. 0 means not
     *          compressible. 1 means every weight is the same value.
     */
    private static compressWeights(outputBase: any, outputAddress: Long, outputSize: number, weights: Int8Array, weightsLength: number, workspace: HuffmanTableWriterWorkspace): number
    {
        if (weightsLength <= 1) {
            return 0; // Not compressible
        }

        // Scan input and build symbol stats
        const counts: Int32Array = workspace.counts;
        Histogram.count(weights, weightsLength, counts);
        const maxSymbol: number = Histogram.findMaxSymbol(counts, Huffman.MAX_TABLE_LOG);
        const maxCount: number = Histogram.findLargestCount(counts, maxSymbol);

        if (maxCount === weightsLength) {
            return 1; // only a single symbol in source
        }
        if (maxCount === 1) {
            return 0; // each symbol present maximum once => not compressible
        }

        const normalizedCounts: Int16Array = workspace.normalizedCounts;

        const tableLog: number = FiniteStateEntropy.optimalTableLog(Huffman.MAX_FSE_TABLE_LOG, weightsLength, maxSymbol);
        FiniteStateEntropy.normalizeCounts(normalizedCounts, tableLog, counts, weightsLength, maxSymbol);

        let output: Long = outputAddress;
        const outputLimit: Long = outputAddress.add(outputSize);

        // write the table description header
        const headerSize: number = FiniteStateEntropy.writeNormalizedCounts(outputBase, output, outputSize, normalizedCounts, maxSymbol, tableLog);
        output = output.add(headerSize);

        // compress the weights themselves
        const compressionTable: FseCompressionTable = workspace.fseTable;
        compressionTable.initialize(normalizedCounts, maxSymbol, tableLog);
        const compressedSize: number = FiniteStateEntropy.compress(outputBase, output,
            outputLimit.subtract(output).toInt(), weights, weightsLength, compressionTable);
        if (compressedSize === 0) {
            return 0;
        }
        output = output.add(compressedSize);

        return output.subtract(outputAddress).toInt();
    }
}
